Your First GAN

Goal

In this notebook, you're going to create your first generative adversarial network (GAN) for this course! Specifically, you will build and train a GAN that can generate images of handwritten digits (0-9). You will be using PyTorch in this specialization, so if you're not familiar with this framework, you may find the PyTorch documentation useful. The hints will also often include links to relevant documentation.

Learning Objectives

  1. Build the generator and discriminator components of a GAN from scratch.
  2. Create generator and discriminator loss functions.
  3. Train your GAN and visualize the generated images.

Getting Started

You will begin by importing some useful packages and the dataset you will use to build and train your GAN. You are also provided with a visualizer function to help you investigate the images your GAN will create.

In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for testing purposes, please do not change!

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in a uniform grid.
    '''
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

MNIST Dataset

The training images your discriminator will use are from a dataset called MNIST. It contains 60,000 images of handwritten digits, from 0 to 9, like these:

MNIST Digits

You may notice that the images are quite pixelated -- this is because they are all only 28 x 28! The small size of its images makes MNIST ideal for simple training. Additionally, these images are grayscale, so only one dimension, or "color channel", is needed to represent them (more on this later in the course).

Tensor

You will represent the data using tensors. Tensors are a generalization of matrices: for example, a stack of three matrices with the amounts of red, green, and blue at different locations in a 64 x 64 pixel image is a tensor with the shape 3 x 64 x 64.

Tensors are easy to manipulate and supported by PyTorch, the machine learning library you will be using. Feel free to explore them more, but you can imagine these as multi-dimensional matrices or vectors!
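As a minimal sketch of the idea above (the 3 x 64 x 64 example, plus the flattening this notebook applies to MNIST images), you can create and reshape tensors like this:

```python
import torch

# A stack of three 64 x 64 matrices (red, green, blue), as described above
image = torch.zeros(3, 64, 64)
print(image.shape)   # torch.Size([3, 64, 64])

# Flattening into a single vector, the way this notebook flattens MNIST images
flat = image.view(-1)
print(flat.shape)    # torch.Size([12288])
```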

Batches

While you could train your model after generating one image, it is extremely inefficient and leads to less stable training. In GANs, and in machine learning in general, you will process multiple images per training step. These are called batches.

This means that your generator will generate an entire batch of images and receive the discriminator's feedback on each before updating the model. The same goes for the discriminator: it will calculate its loss on the entire batch of generated images as well as on the real ones before the model is updated.
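A quick sketch of how batching shows up in the tensor shapes (the batch size of 128 here is illustrative; the actual values are set later in the notebook):

```python
import torch

batch_size, z_dim = 128, 64

# One noise vector per image in the batch: shape (batch_size, z_dim).
# A generator maps the whole batch in a single forward pass.
noise = torch.randn(batch_size, z_dim)
print(noise.shape)   # torch.Size([128, 64])

# Losses are averaged over the batch, so one update uses feedback
# from every image in the batch at once
per_image_loss = torch.rand(batch_size)
batch_loss = per_image_loss.mean()
```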

Generator

The first step is to build the generator component.

You will start by creating a function to make a single layer/block for the generator's neural network. Each block should include a linear transformation to map to another shape, a batch normalization for stabilization, and finally a non-linear activation function (you use a ReLU here) so the output can be transformed in complex ways. You will learn more about activations and batch normalization later in the course.

In [2]:
# UNQ_C1 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_generator_block
def get_generator_block(input_dim, output_dim):
    '''
    Function for returning a block of the generator's neural network
    given input and output dimensions.
    Parameters:
        input_dim: the dimension of the input vector, a scalar
        output_dim: the dimension of the output vector, a scalar
    Returns:
        a generator neural network layer, with a linear transformation 
          followed by a batch normalization and then a relu activation
    '''
    return nn.Sequential(
        # Hint: Replace all of the "None" with the appropriate dimensions.
        # The documentation may be useful if you're less familiar with PyTorch:
        # https://pytorch.org/docs/stable/nn.html.
        #### START CODE HERE ####
        nn.Linear(input_dim, output_dim),
        nn.BatchNorm1d(output_dim),
        nn.ReLU(inplace=True),
        #### END CODE HERE ####
    )
In [3]:
# Verify the generator block function
def test_gen_block(in_features, out_features, num_test=1000):
    block = get_generator_block(in_features, out_features)

    # Check the three parts
    assert len(block) == 3
    assert type(block[0]) == nn.Linear
    assert type(block[1]) == nn.BatchNorm1d
    assert type(block[2]) == nn.ReLU
    
    # Check the output shape
    test_input = torch.randn(num_test, in_features)
    test_output = block(test_input)
    assert tuple(test_output.shape) == (num_test, out_features)
    assert test_output.std() > 0.55
    assert test_output.std() < 0.65

test_gen_block(25, 12)
test_gen_block(15, 28)
print("Success!")
Success!

Now you can build the generator class. It will take 3 values:

  • The noise vector dimension
  • The image dimension
  • The initial hidden dimension

Using these values, the generator will build a neural network with 5 layers/blocks. Beginning with the noise vector, the generator will apply non-linear transformations via the block function until the tensor is mapped to the size of the image to be output (the same size as the real images from MNIST). You will need to fill in the code for the final layer since it is different from the others. The final layer does not need a normalization or activation function, but does need to be scaled with a sigmoid function.

Finally, you are given a forward pass function that takes in a noise vector and generates an image of the output dimension using your neural network.
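To see how the hidden dimension doubles at each block, here is the dimension progression worked out for the default values (z_dim=10, hidden_dim=128, im_dim=784):

```python
# Dimension progression through the default generator: each of the four
# blocks doubles the hidden size, and the final linear layer maps to im_dim
z_dim, hidden_dim, im_dim = 10, 128, 784
dims = [z_dim] + [hidden_dim * 2 ** i for i in range(4)] + [im_dim]
print(dims)   # [10, 128, 256, 512, 1024, 784]
```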

Optional hints for Generator

  1. The output size of the final linear transformation should be im_dim, but remember you need to scale the outputs between 0 and 1 using the sigmoid function.
  2. [nn.Linear](https://pytorch.org/docs/master/generated/torch.nn.Linear.html) and [nn.Sigmoid](https://pytorch.org/docs/master/generated/torch.nn.Sigmoid.html) will be useful here.
In [4]:
# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: Generator
class Generator(nn.Module):
    '''
    Generator Class
    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_dim: the dimension of the images, fitted for the dataset used, a scalar
          (MNIST images are 28 x 28 = 784 so that is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):
        super(Generator, self).__init__()
        # Build the neural network
        self.gen = nn.Sequential(
            get_generator_block(z_dim, hidden_dim),
            get_generator_block(hidden_dim, hidden_dim * 2),
            get_generator_block(hidden_dim * 2, hidden_dim * 4),
            get_generator_block(hidden_dim * 4, hidden_dim * 8),
            # There is a dropdown with hints if you need them! 
            #### START CODE HERE ####
            nn.Linear(hidden_dim * 8, im_dim),
            nn.Sigmoid()
            #### END CODE HERE ####
        )
    def forward(self, noise):
        '''
        Function for completing a forward pass of the generator: Given a noise tensor, 
        returns generated images.
        Parameters:
            noise: a noise tensor with dimensions (n_samples, z_dim)
        '''
        return self.gen(noise)
    
    # Needed for grading
    def get_gen(self):
        '''
        Returns:
            the sequential model
        '''
        return self.gen
In [5]:
# Verify the generator class
def test_generator(z_dim, im_dim, hidden_dim, num_test=10000):
    gen = Generator(z_dim, im_dim, hidden_dim).get_gen()
    
    # Check there are six modules in the sequential part
    assert len(gen) == 6
    test_input = torch.randn(num_test, z_dim)
    test_output = gen(test_input)

    # Check that the output shape is correct
    assert tuple(test_output.shape) == (num_test, im_dim)
    assert test_output.max() < 1, "Make sure to use a sigmoid"
    assert test_output.min() > 0, "Make sure to use a sigmoid"
    assert test_output.std() > 0.05, "Don't use batchnorm here"
    assert test_output.std() < 0.15, "Don't use batchnorm here"

test_generator(5, 10, 20)
test_generator(20, 8, 24)
print("Success!")
Success!

Noise

To be able to use your generator, you will need to be able to create noise vectors. The noise vector z has the important role of making sure the images generated from the same class don't all look the same -- think of it as a random seed. You will generate it randomly using PyTorch by sampling random numbers from the normal distribution. Since multiple images will be processed per pass, you will generate all the noise vectors at once.

Note that whenever you create a new tensor using torch.ones, torch.zeros, or torch.randn, you either need to create it on the target device, e.g. torch.ones(3, 3, device=device), or move it onto the target device using torch.ones(3, 3).to(device). You do not need to do this if you're creating a tensor by manipulating another tensor or by using a variation that defaults the device to the input, such as torch.ones_like. In general, use torch.ones_like and torch.zeros_like instead of torch.ones or torch.zeros where possible.
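A small check of the device behavior described above -- torch.ones_like inherits the shape and device of its input, while torch.ones defaults to the CPU unless told otherwise:

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
a = torch.randn(4, 4, device=device)

# ones_like copies both shape and device from its input
b = torch.ones_like(a)
assert b.device == a.device and b.shape == a.shape

# torch.ones needs the device passed explicitly to land on the same device
c = torch.ones(4, 4, device=device)
assert c.device == a.device
```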

Optional hint for get_noise

  1. You will probably find [torch.randn](https://pytorch.org/docs/master/generated/torch.randn.html) useful here.
In [6]:
# UNQ_C3 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_noise
def get_noise(n_samples, z_dim, device='cpu'):
    '''
    Function for creating noise vectors: Given the dimensions (n_samples, z_dim),
    creates a tensor of that shape filled with random numbers from the normal distribution.
    Parameters:
        n_samples: the number of samples to generate, a scalar
        z_dim: the dimension of the noise vector, a scalar
        device: the device type
    '''
    # NOTE: To use this on GPU with device='cuda', make sure to pass the device 
    # argument to the function you use to generate the noise.
    #### START CODE HERE ####
    return torch.randn(n_samples, z_dim, device=device)
    #### END CODE HERE ####
In [7]:
# Verify the noise vector function
def test_get_noise(n_samples, z_dim, device='cpu'):
    noise = get_noise(n_samples, z_dim, device)
    
    # Make sure a normal distribution was used
    assert tuple(noise.shape) == (n_samples, z_dim)
    assert torch.abs(noise.std() - torch.tensor(1.0)) < 0.01
    assert str(noise.device).startswith(device)

test_get_noise(1000, 100, 'cpu')
if torch.cuda.is_available():
    test_get_noise(1000, 32, 'cuda')
print("Success!")
Success!

Discriminator

The second component that you need to construct is the discriminator. As with the generator component, you will start by creating a function that builds a neural network block for the discriminator.

Note: You use leaky ReLUs to prevent the "dying ReLU" problem, which refers to the phenomenon where the parameters stop changing due to consistently negative values passed to a ReLU, which result in a zero gradient. You will learn more about this in the following lectures!

Rectified Linear Unit (ReLU) vs. Leaky ReLU
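A quick comparison of the two activations (illustrative values): with a negative slope of 0.2, negative inputs are scaled down rather than zeroed out, so the gradient never vanishes entirely.

```python
import torch
from torch import nn

x = torch.tensor([-2.0, -1.0, 0.0, 1.0])

relu_out = nn.ReLU()(x)           # tensor([0., 0., 0., 1.])
leaky_out = nn.LeakyReLU(0.2)(x)  # tensor([-0.4000, -0.2000, 0.0000, 1.0000])
```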
In [8]:
# UNQ_C4 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_discriminator_block
def get_discriminator_block(input_dim, output_dim):
    '''
    Discriminator Block
    Function for returning a neural network of the discriminator given input and output dimensions.
    Parameters:
        input_dim: the dimension of the input vector, a scalar
        output_dim: the dimension of the output vector, a scalar
    Returns:
        a discriminator neural network layer, with a linear transformation 
          followed by an nn.LeakyReLU activation with negative slope of 0.2 
          (https://pytorch.org/docs/master/generated/torch.nn.LeakyReLU.html)
    '''
    return nn.Sequential(
        #### START CODE HERE ####
        nn.Linear(input_dim, output_dim),
        nn.LeakyReLU(0.2, inplace=True), 
        #### END CODE HERE ####
    )
In [9]:
# Verify the discriminator block function
def test_disc_block(in_features, out_features, num_test=10000):
    block = get_discriminator_block(in_features, out_features)

    # Check there are two parts
    assert len(block) == 2
    test_input = torch.randn(num_test, in_features)
    test_output = block(test_input)

    # Check that the shape is right
    assert tuple(test_output.shape) == (num_test, out_features)
    
    # Check that the LeakyReLU slope is about 0.2
    assert -test_output.min() / test_output.max() > 0.1
    assert -test_output.min() / test_output.max() < 0.3
    assert test_output.std() > 0.3
    assert test_output.std() < 0.5

test_disc_block(25, 12)
test_disc_block(15, 28)
print("Success!")
Success!

Now you can use these blocks to make a discriminator! The discriminator class holds 2 values:

  • The image dimension
  • The hidden dimension

The discriminator will build a neural network with 4 layers. It will start with the image tensor and transform it until it returns a single number (a 1-dimensional tensor) as output. This output classifies whether an image is fake or real. Note that you do not need a sigmoid after the output layer since it is included in the loss function. Finally, to use your discriminator's neural network, you are given a forward pass function that takes in an image tensor to be classified.
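Mirroring the generator, the discriminator's hidden dimension halves at each block. Worked out for the default values (im_dim=784, hidden_dim=128):

```python
# Dimension progression through the default discriminator: the image vector
# is narrowed block by block until a single real/fake score remains
im_dim, hidden_dim = 784, 128
dims = [im_dim, hidden_dim * 4, hidden_dim * 2, hidden_dim, 1]
print(dims)   # [784, 512, 256, 128, 1]
```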

In [10]:
# UNQ_C5 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: Discriminator
class Discriminator(nn.Module):
    '''
    Discriminator Class
    Values:
        im_dim: the dimension of the images, fitted for the dataset used, a scalar
            (MNIST images are 28x28 = 784 so that is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, im_dim=784, hidden_dim=128):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            get_discriminator_block(im_dim, hidden_dim * 4),
            get_discriminator_block(hidden_dim * 4, hidden_dim * 2),
            get_discriminator_block(hidden_dim * 2, hidden_dim),
            # Hint: You want to transform the final output into a single value,
            #       so add one more linear map.
            #### START CODE HERE ####
            nn.Linear(hidden_dim, 1)
            #### END CODE HERE ####
        )

    def forward(self, image):
        '''
        Function for completing a forward pass of the discriminator: Given an image tensor, 
        returns a 1-dimension tensor representing fake/real.
        Parameters:
            image: a flattened image tensor with dimension (im_dim)
        '''
        return self.disc(image)
    
    # Needed for grading
    def get_disc(self):
        '''
        Returns:
            the sequential model
        '''
        return self.disc
In [11]:
# Verify the discriminator class
def test_discriminator(z_dim, hidden_dim, num_test=100):
    
    disc = Discriminator(z_dim, hidden_dim).get_disc()

    # Check there are four parts
    assert len(disc) == 4

    # Check the linear layer is correct
    test_input = torch.randn(num_test, z_dim)
    test_output = disc(test_input)
    assert tuple(test_output.shape) == (num_test, 1)
    
    # Make sure there's no sigmoid
    assert test_input.max() > 1
    assert test_input.min() < -1

test_discriminator(5, 10)
test_discriminator(20, 8)
print("Success!")
Success!

Training

Now you can put it all together! First, you will set your parameters:

  • criterion: the loss function
  • n_epochs: the number of times you iterate through the entire dataset when training
  • z_dim: the dimension of the noise vector
  • display_step: how often to display/visualize the images
  • batch_size: the number of images per forward/backward pass
  • lr: the learning rate
  • device: the device type, here using a GPU (which runs CUDA), not CPU

Next, you will load the MNIST dataset as tensors using a dataloader.

In [12]:
# Set your parameters
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.00001

# Load MNIST dataset as tensors
dataloader = DataLoader(
    MNIST('.', download=False, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True)

### DO NOT EDIT ###
device = 'cuda'

Now, you can initialize your generator, discriminator, and optimizers. Note that each optimizer only takes the parameters of one particular model, since we want each optimizer to optimize only one of the models.

In [13]:
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator().to(device) 
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

Before you train your GAN, you will need to create functions to calculate the discriminator's loss and the generator's loss. This is how the discriminator and generator will know how they are doing and improve themselves. Since the generator is needed when calculating the discriminator's loss, you will need to call .detach() on the generator result to ensure that only the discriminator is updated!
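A tiny sketch of what .detach() does to the computation graph (the two parameters here are stand-ins, not the actual models):

```python
import torch

w = torch.ones(1, requires_grad=True)  # stands in for a generator parameter
v = torch.ones(1, requires_grad=True)  # stands in for a discriminator parameter

out = (w * 2).detach() * v   # .detach() cuts the graph on the generator side
out.sum().backward()

assert w.grad is None        # the "generator" receives no gradient
assert v.grad is not None    # the "discriminator" still does
```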

Remember that you have already defined a loss function earlier (criterion) and you are encouraged to use torch.ones_like and torch.zeros_like instead of torch.ones or torch.zeros. If you use torch.ones or torch.zeros, you'll need to pass device=device to them.
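As a small illustration of how criterion pairs with torch.zeros_like and torch.ones_like, using hypothetical discriminator logits (not values from this notebook):

```python
import torch
from torch import nn

criterion = nn.BCEWithLogitsLoss()

# Hypothetical discriminator logits for two images
pred = torch.tensor([[2.0], [-1.0]])

# Ground truth for fakes is all zeros; zeros_like matches shape and device
fake_loss = criterion(pred, torch.zeros_like(pred))

# Ground truth for reals is all ones
real_loss = criterion(pred, torch.ones_like(pred))
```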

In [14]:
# UNQ_C6 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_disc_loss
def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):
    '''
    Return the loss of the discriminator given inputs.
    Parameters:
        gen: the generator model, which returns an image given z-dimensional noise
        disc: the discriminator model, which returns a single-dimensional prediction of real/fake
        criterion: the loss function, which should be used to compare 
               the discriminator's predictions to the ground truth reality of the images 
               (e.g. fake = 0, real = 1)
        real: a batch of real images
        num_images: the number of images the generator should produce, 
                which is also the length of the real images
        z_dim: the dimension of the noise vector, a scalar
        device: the device type
    Returns:
        disc_loss: a torch scalar loss value for the current batch
    '''
    #     These are the steps you will need to complete:
    #       1) Create noise vectors and generate a batch (num_images) of fake images. 
    #            Make sure to pass the device argument to the noise.
    #       2) Get the discriminator's prediction of the fake image 
    #            and calculate the loss. Don't forget to detach the generator!
    #            (Remember the loss function you set earlier -- criterion. You need a 
    #            'ground truth' tensor in order to calculate the loss. 
    #            For example, a ground truth tensor for a fake image is all zeros.)
    #       3) Get the discriminator's prediction of the real image and calculate the loss.
    #       4) Calculate the discriminator's loss by averaging the real and fake loss
    #            and set it to disc_loss.
    #     Note: Please do not use concatenation in your solution. The tests are being updated to 
    #           support this, but for now, average the two losses as described in step (4).
    #     *Important*: You should NOT write your own loss function here - use criterion(pred, true)!
    #### START CODE HERE ####
    noise = get_noise(num_images, z_dim, device=device)

    fake_image = gen(noise)
    fake_output = disc(fake_image.detach())
    loss_disc_fake = criterion(fake_output, torch.zeros_like(fake_output))

    real_output = disc(real)
    loss_disc_real = criterion(real_output, torch.ones_like(real_output))

    disc_loss = (loss_disc_real + loss_disc_fake) / 2
    #### END CODE HERE ####
    return disc_loss
In [15]:
def test_disc_reasonable(num_images=10):
    # Don't use explicit casts to cuda - use the device argument
    import inspect, re
    lines = inspect.getsource(get_disc_loss)
    assert (re.search(r"to\(.cuda.\)", lines)) is None
    assert (re.search(r"\.cuda\(\)", lines)) is None
    
    z_dim = 64
    gen = torch.zeros_like
    disc = lambda x: x.mean(1)[:, None]
    criterion = torch.mul # Multiply
    real = torch.ones(num_images, z_dim)
    disc_loss = get_disc_loss(gen, disc, criterion, real, num_images, z_dim, 'cpu')
    assert torch.all(torch.abs(disc_loss.mean() - 0.5) < 1e-5)
    
    gen = torch.ones_like
    criterion = torch.mul # Multiply
    real = torch.zeros(num_images, z_dim)
    assert torch.all(torch.abs(get_disc_loss(gen, disc, criterion, real, num_images, z_dim, 'cpu')) < 1e-5)
    
    gen = lambda x: torch.ones(num_images, 10)
    disc = lambda x: x.mean(1)[:, None] + 10
    criterion = torch.mul # Multiply
    real = torch.zeros(num_images, 10)
    assert torch.all(torch.abs(get_disc_loss(gen, disc, criterion, real, num_images, z_dim, 'cpu').mean() - 5) < 1e-5)

    gen = torch.ones_like
    disc = nn.Linear(64, 1, bias=False)
    real = torch.ones(num_images, 64) * 0.5
    disc.weight.data = torch.ones_like(disc.weight.data) * 0.5
    disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
    criterion = lambda x, y: torch.sum(x) + torch.sum(y)
    disc_loss = get_disc_loss(gen, disc, criterion, real, num_images, z_dim, 'cpu').mean()
    disc_loss.backward()
    assert torch.isclose(torch.abs(disc.weight.grad.mean() - 11.25), torch.tensor(3.75))
    
def test_disc_loss(max_tests = 10):
    z_dim = 64
    gen = Generator(z_dim).to(device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
    disc = Discriminator().to(device) 
    disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
    num_steps = 0
    for real, _ in dataloader:
        cur_batch_size = len(real)
        real = real.view(cur_batch_size, -1).to(device)

        ### Update discriminator ###
        # Zero out the gradient before backpropagation
        disc_opt.zero_grad()

        # Calculate discriminator loss
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)
        assert (disc_loss - 0.68).abs() < 0.05

        # Update gradients
        disc_loss.backward(retain_graph=True)

        # Check that they detached correctly
        assert gen.gen[0][0].weight.grad is None

        # Update optimizer
        old_weight = disc.disc[0][0].weight.data.clone()
        disc_opt.step()
        new_weight = disc.disc[0][0].weight.data
        
        # Check that some discriminator weights changed
        assert not torch.all(torch.eq(old_weight, new_weight))
        num_steps += 1
        if num_steps >= max_tests:
            break

test_disc_reasonable()
test_disc_loss()
print("Success!")
Success!
In [16]:
# UNQ_C7 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_gen_loss
def get_gen_loss(gen, disc, criterion, num_images, z_dim, device):
    '''
    Return the loss of the generator given inputs.
    Parameters:
        gen: the generator model, which returns an image given z-dimensional noise
        disc: the discriminator model, which returns a single-dimensional prediction of real/fake
        criterion: the loss function, which should be used to compare 
               the discriminator's predictions to the ground truth reality of the images 
               (e.g. fake = 0, real = 1)
        num_images: the number of images the generator should produce, 
                which is also the length of the real images
        z_dim: the dimension of the noise vector, a scalar
        device: the device type
    Returns:
        gen_loss: a torch scalar loss value for the current batch
    '''
    #     These are the steps you will need to complete:
    #       1) Create noise vectors and generate a batch of fake images. 
    #           Remember to pass the device argument to the get_noise function.
    #       2) Get the discriminator's prediction of the fake image.
    #       3) Calculate the generator's loss. Remember the generator wants
    #          the discriminator to think that its fake images are real
    #     *Important*: You should NOT write your own loss function here - use criterion(pred, true)!

    #### START CODE HERE ####
    noise = get_noise(num_images, z_dim, device=device)
    fake_image = gen(noise)
    fake_output = disc(fake_image)
    gen_loss = criterion(fake_output, torch.ones_like(fake_output))

    #### END CODE HERE ####
    return gen_loss
In [17]:
def test_gen_reasonable(num_images=10):
    # Don't use explicit casts to cuda - use the device argument
    import inspect, re
    lines = inspect.getsource(get_gen_loss)
    assert (re.search(r"to\(.cuda.\)", lines)) is None
    assert (re.search(r"\.cuda\(\)", lines)) is None
    
    z_dim = 64
    gen = torch.zeros_like
    disc = nn.Identity()
    criterion = torch.mul # Multiply
    gen_loss_tensor = get_gen_loss(gen, disc, criterion, num_images, z_dim, 'cpu')
    assert torch.all(torch.abs(gen_loss_tensor) < 1e-5)
    #Verify shape. Related to gen_noise parametrization
    assert tuple(gen_loss_tensor.shape) == (num_images, z_dim)

    gen = torch.ones_like
    disc = nn.Identity()
    criterion = torch.mul # Multiply
    real = torch.zeros(num_images, 1)
    gen_loss_tensor = get_gen_loss(gen, disc, criterion, num_images, z_dim, 'cpu')
    assert torch.all(torch.abs(gen_loss_tensor - 1) < 1e-5)
    #Verify shape. Related to gen_noise parametrization
    assert tuple(gen_loss_tensor.shape) == (num_images, z_dim)
    

def test_gen_loss(num_images):
    z_dim = 64
    gen = Generator(z_dim).to(device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
    disc = Discriminator().to(device) 
    disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)
    
    gen_loss = get_gen_loss(gen, disc, criterion, num_images, z_dim, device)
    
    # Check that the loss is reasonable
    assert (gen_loss - 0.7).abs() < 0.1
    gen_loss.backward()
    old_weight = gen.gen[0][0].weight.clone()
    gen_opt.step()
    new_weight = gen.gen[0][0].weight
    assert not torch.all(torch.eq(old_weight, new_weight))


test_gen_reasonable(10)
test_gen_loss(18)
print("Success!")
Success!

Finally, you can put everything together! For each epoch, you will process the entire dataset in batches. For every batch, you will need to update the discriminator and generator using their losses. Batches are sets of images that are predicted on before the loss functions are calculated (instead of calculating the loss after each image). Note that you may see a loss greater than 1; this is okay, since binary cross entropy loss can be any positive number for a sufficiently confident wrong guess.

It’s also often the case that the discriminator will outperform the generator, especially at the start, because its job is easier. It's important that neither one gets too good (that is, near-perfect accuracy), which would cause the entire model to stop learning. Balancing the two models is actually remarkably hard to do in a standard GAN and something you will see more of in later lectures and assignments.

After you've submitted a working version with the original architecture, feel free to play around with the architecture if you want to see how different architectural choices can lead to better or worse GANs. For example, consider changing the size of the hidden dimension, or making the networks shallower or deeper by changing the number of layers.

But remember, don’t expect anything spectacular: this is only the first lesson. The results will get better with later lessons as you learn methods to help keep your generator and discriminator at similar levels.

You should roughly expect to see this progression. On a GPU, this should take about 15 seconds per 500 steps, on average, while on CPU it will take roughly 1.5 minutes: MNIST Digits

In [19]:
# UNQ_C8 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: 

cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
test_generator = True # Whether the generator should be tested
gen_loss = False
error = False
for epoch in range(n_epochs):
  
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)

        # Flatten the batch of real images from the dataset
        real = real.view(cur_batch_size, -1).to(device)

        ### Update discriminator ###
        # Zero out the gradients before backpropagation
        disc_opt.zero_grad()

        # Calculate discriminator loss
        disc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)

        # Update gradients
        disc_loss.backward(retain_graph=True)

        # Update optimizer
        disc_opt.step()

        # For testing purposes, to keep track of the generator weights
        if test_generator:
            old_generator_weights = gen.gen[0][0].weight.detach().clone()

        ### Update generator ###
        #     Hint: This code will look a lot like the discriminator updates!
        #     These are the steps you will need to complete:
        #       1) Zero out the gradients.
        #       2) Calculate the generator loss, assigning it to gen_loss.
        #       3) Backprop through the generator: update the gradients and optimizer.
        #### START CODE HERE ####
        gen_opt.zero_grad()
        gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)
        gen_loss.backward(retain_graph=True)
        gen_opt.step()
        #### END CODE HERE ####

        # For testing purposes, to check that your code changes the generator weights
        if test_generator:
            try:
                assert lr > 0.0000002 or (gen.gen[0][0].weight.grad.abs().max() < 0.0005 and epoch == 0)
                assert torch.any(gen.gen[0][0].weight.detach().clone() != old_generator_weights)
            except:
                error = True
                print("Runtime tests have failed")

        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step

        # Keep track of the average generator loss
        mean_generator_loss += gen_loss.item() / display_step

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            print(f"Epoch {epoch}, step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            show_tensor_images(fake)
            show_tensor_images(real)
            mean_generator_loss = 0
            mean_discriminator_loss = 0
        cur_step += 1

Epoch 1, step 500: Generator loss: 1.4021892887353895, discriminator loss: 0.41650343376398125

Epoch 2, step 1000: Generator loss: 1.7547867209911354, discriminator loss: 0.27905121481418593

Epoch 3, step 1500: Generator loss: 2.0534866051673886, discriminator loss: 0.16022501187026497

Epoch 4, step 2000: Generator loss: 1.6590790162086473, discriminator loss: 0.230010802447796

Epoch 5, step 2500: Generator loss: 1.6464211375713351, discriminator loss: 0.20785944911837578

Epoch 6, step 3000: Generator loss: 1.8679591503143314, discriminator loss: 0.18165835958719256

Epoch 7, step 3500: Generator loss: 2.281863901376724, discriminator loss: 0.14068569943308815

Epoch 8, step 4000: Generator loss: 2.7038585433959956, discriminator loss: 0.10645679509639738

Epoch 9, step 4500: Generator loss: 3.2159018826484695, discriminator loss: 0.07902633555233485

Epoch 10, step 5000: Generator loss: 3.4632100396156313, discriminator loss: 0.06536077416688203

Epoch 11, step 5500: Generator loss: 3.79067779636383, discriminator loss: 0.053821022659540176

Epoch 12, step 6000: Generator loss: 4.015103528022767, discriminator loss: 0.04951630744710562

Epoch 13, step 6500: Generator loss: 4.094020510196684, discriminator loss: 0.04808338173106315

Epoch 14, step 7000: Generator loss: 4.327019422054291, discriminator loss: 0.045291932318359665

Epoch 15, step 7500: Generator loss: 4.266022222042087, discriminator loss: 0.05334099601209168


Epoch 17, step 8000: Generator loss: 4.326726165294647, discriminator loss: 0.0541783450841904

Epoch 18, step 8500: Generator loss: 4.28860838031769, discriminator loss: 0.04768068019673223

Epoch 19, step 9000: Generator loss: 4.32924272823334, discriminator loss: 0.04568632438778876

Epoch 20, step 9500: Generator loss: 4.429699569702155, discriminator loss: 0.043975398533046245

Epoch 21, step 10000: Generator loss: 4.616788505554203, discriminator loss: 0.04617461739480495

Epoch 22, step 10500: Generator loss: 4.349149279594423, discriminator loss: 0.05810767061635853

Epoch 23, step 11000: Generator loss: 4.165111516475675, discriminator loss: 0.06640870022028683

Epoch 24, step 11500: Generator loss: 4.087418521404266, discriminator loss: 0.07283192903175954

Epoch 25, step 12000: Generator loss: 4.0554632601737985, discriminator loss: 0.07039332207292312

Epoch 26, step 12500: Generator loss: 3.8847192630767817, discriminator loss: 0.08408918690308918

Epoch 27, step 13000: Generator loss: 3.8577137436866775, discriminator loss: 0.08629989367723462

Epoch 28, step 13500: Generator loss: 4.055837862014769, discriminator loss: 0.06891949273273347

Epoch 29, step 14000: Generator loss: 4.088778685569762, discriminator loss: 0.08247101604938506

Epoch 30, step 14500: Generator loss: 4.167207420825958, discriminator loss: 0.08635814798623319

Epoch 31, step 15000: Generator loss: 4.051998943805697, discriminator loss: 0.0926095475703478


Epoch 33, step 15500: Generator loss: 3.906418939113617, discriminator loss: 0.10214256211370236

Epoch 34, step 16000: Generator loss: 3.8070238747596727, discriminator loss: 0.10365706507861606

Epoch 35, step 16500: Generator loss: 3.599786550045011, discriminator loss: 0.12012411704659468

Epoch 36, step 17000: Generator loss: 3.4502770071029634, discriminator loss: 0.12028309256583461

Epoch 37, step 17500: Generator loss: 3.423109201431277, discriminator loss: 0.1265528633370995

Epoch 38, step 18000: Generator loss: 3.4674059338569636, discriminator loss: 0.13159299311041833

Epoch 39, step 18500: Generator loss: 3.5074359421730033, discriminator loss: 0.12356326410919417

Epoch 40, step 19000: Generator loss: 3.3632523379325905, discriminator loss: 0.13338379342854034

Epoch 41, step 19500: Generator loss: 3.22603641366959, discriminator loss: 0.14790745033323754

Epoch 42, step 20000: Generator loss: 3.2789574327468887, discriminator loss: 0.15603232999145966

Epoch 43, step 20500: Generator loss: 3.3287756581306427, discriminator loss: 0.1399594063609839

Epoch 44, step 21000: Generator loss: 3.1529601063728347, discriminator loss: 0.1596593412756919

Epoch 45, step 21500: Generator loss: 3.1682705559730517, discriminator loss: 0.15389567883312702

Epoch 46, step 22000: Generator loss: 3.400409095287324, discriminator loss: 0.1284988962113859

Epoch 47, step 22500: Generator loss: 3.4216985025405884, discriminator loss: 0.14327448347210894


Epoch 49, step 23000: Generator loss: 3.226895459651946, discriminator loss: 0.16510137757658952

Epoch 50, step 23500: Generator loss: 3.0258228740692115, discriminator loss: 0.17053525249660015

Epoch 51, step 24000: Generator loss: 3.1062900247573846, discriminator loss: 0.14867938023060578

Epoch 52, step 24500: Generator loss: 3.117117701053616, discriminator loss: 0.1576450464874505

Epoch 53, step 25000: Generator loss: 2.9933766942024227, discriminator loss: 0.16798575283586986

Epoch 54, step 25500: Generator loss: 3.1737692122459413, discriminator loss: 0.14725146122276778

Epoch 55, step 26000: Generator loss: 3.1020456552505484, discriminator loss: 0.16312602414190774

Epoch 56, step 26500: Generator loss: 3.037220055580141, discriminator loss: 0.1679544922858477

Epoch 57, step 27000: Generator loss: 2.886235658168792, discriminator loss: 0.18317969346046453

Epoch 58, step 27500: Generator loss: 3.0850545754432677, discriminator loss: 0.16010782542824728

Epoch 59, step 28000: Generator loss: 2.9465291552543658, discriminator loss: 0.18082491309940832

Epoch 60, step 28500: Generator loss: 3.1555335435867304, discriminator loss: 0.16766628864407543

Epoch 61, step 29000: Generator loss: 2.9077837586402913, discriminator loss: 0.18663176983594884

Epoch 62, step 29500: Generator loss: 2.9999897756576543, discriminator loss: 0.16929173508286488

Epoch 63, step 30000: Generator loss: 3.004062754154207, discriminator loss: 0.19283815248310573


Epoch 65, step 30500: Generator loss: 2.8121536002159115, discriminator loss: 0.19729897643625743

Epoch 66, step 31000: Generator loss: 2.7960010857582063, discriminator loss: 0.1912673514634372

Epoch 67, step 31500: Generator loss: 2.7313903932571404, discriminator loss: 0.2076583785861729

Epoch 68, step 32000: Generator loss: 2.7977543506622324, discriminator loss: 0.1971935890763998

Epoch 69, step 32500: Generator loss: 2.736099233150484, discriminator loss: 0.2185090746730565

Epoch 70, step 33000: Generator loss: 2.4949338834285726, discriminator loss: 0.23879047545790674

Epoch 71, step 33500: Generator loss: 2.536931770801545, discriminator loss: 0.24109960806369785

Epoch 72, step 34000: Generator loss: 2.5381627283096315, discriminator loss: 0.2321193739771842

Epoch 73, step 34500: Generator loss: 2.6110821628570524, discriminator loss: 0.21891257108747944

Epoch 74, step 35000: Generator loss: 2.49381052017212, discriminator loss: 0.23904302278161058

Epoch 75, step 35500: Generator loss: 2.396353168725965, discriminator loss: 0.25981025123596196

Epoch 76, step 36000: Generator loss: 2.415967663288117, discriminator loss: 0.2549662430286407

Epoch 77, step 36500: Generator loss: 2.5166643891334535, discriminator loss: 0.22039907833933844

Epoch 78, step 37000: Generator loss: 2.414042107343672, discriminator loss: 0.2519847467839718

Epoch 79, step 37500: Generator loss: 2.4786592149734497, discriminator loss: 0.2480662245750427


Epoch 81, step 38000: Generator loss: 2.420161560773848, discriminator loss: 0.2525256572365759

Epoch 82, step 38500: Generator loss: 2.309714781045914, discriminator loss: 0.268059366673231

Epoch 83, step 39000: Generator loss: 2.3388535225391385, discriminator loss: 0.2575041563808917

Epoch 84, step 39500: Generator loss: 2.349515162467955, discriminator loss: 0.2593803414106369

Epoch 85, step 40000: Generator loss: 2.4101956586837785, discriminator loss: 0.24844183737039569

Epoch 86, step 40500: Generator loss: 2.2764426503181445, discriminator loss: 0.2687814032435418

Epoch 87, step 41000: Generator loss: 2.3060552699565884, discriminator loss: 0.2555863273441791

Epoch 88, step 41500: Generator loss: 2.2887331538200395, discriminator loss: 0.25911325177550326

Epoch 89, step 42000: Generator loss: 2.118716900348665, discriminator loss: 0.29184505426883717

Epoch 90, step 42500: Generator loss: 2.212179771900177, discriminator loss: 0.2678806773722174

Epoch 91, step 43000: Generator loss: 2.1505995123386357, discriminator loss: 0.2951893527507784

Epoch 92, step 43500: Generator loss: 2.07681686425209, discriminator loss: 0.29788028573989866

Epoch 93, step 44000: Generator loss: 2.1545186591148364, discriminator loss: 0.27460581994056693

Epoch 94, step 44500: Generator loss: 2.243549757719041, discriminator loss: 0.26621534839272487

Epoch 95, step 45000: Generator loss: 2.1012500779628738, discriminator loss: 0.29820541495084735


Epoch 97, step 45500: Generator loss: 2.070566725969312, discriminator loss: 0.30606148165464403

Epoch 98, step 46000: Generator loss: 2.0449455878734595, discriminator loss: 0.3045798503756522

Epoch 99, step 46500: Generator loss: 2.151355919122697, discriminator loss: 0.27931378784775723

Epoch 100, step 47000: Generator loss: 2.113484380960464, discriminator loss: 0.31309118512272804

Epoch 101, step 47500: Generator loss: 2.0569242799282073, discriminator loss: 0.3054923801422118

Epoch 102, step 48000: Generator loss: 1.9198945467472066, discriminator loss: 0.34593918418884306

Epoch 103, step 48500: Generator loss: 1.8925537531375891, discriminator loss: 0.3277647721767428

Epoch 104, step 49000: Generator loss: 1.9225851898193356, discriminator loss: 0.3200696520209311

Epoch 105, step 49500: Generator loss: 1.963495755434035, discriminator loss: 0.314146119207144

Epoch 106, step 50000: Generator loss: 2.0160290994644177, discriminator loss: 0.29577211722731583

Epoch 107, step 50500: Generator loss: 2.020676573514938, discriminator loss: 0.3141472573876378

Epoch 108, step 51000: Generator loss: 1.899706240415572, discriminator loss: 0.3386503512263297

Epoch 109, step 51500: Generator loss: 1.9027093632221215, discriminator loss: 0.3201370184421537

Epoch 110, step 52000: Generator loss: 2.0707068309783914, discriminator loss: 0.29698964014649365

Epoch 111, step 52500: Generator loss: 2.104400192499162, discriminator loss: 0.2886533603966238


Epoch 113, step 53000: Generator loss: 2.030498737812044, discriminator loss: 0.3075546867251394

Epoch 114, step 53500: Generator loss: 2.0020633275508883, discriminator loss: 0.3170958526730537

Epoch 115, step 54000: Generator loss: 1.9024381937980663, discriminator loss: 0.328866899549961

Epoch 116, step 54500: Generator loss: 1.9211840703487397, discriminator loss: 0.3142661949992181

Epoch 117, step 55000: Generator loss: 1.834452604770661, discriminator loss: 0.3390869683623314

Epoch 118, step 55500: Generator loss: 1.7915789806842803, discriminator loss: 0.3466750149130823

Epoch 119, step 56000: Generator loss: 1.862494437456131, discriminator loss: 0.3343907089531418

Epoch 120, step 56500: Generator loss: 1.7601128582954413, discriminator loss: 0.36980430704355227

Epoch 121, step 57000: Generator loss: 1.6228554217815416, discriminator loss: 0.3854625148773199

Epoch 122, step 57500: Generator loss: 1.7724433429241175, discriminator loss: 0.3478498765528205

Epoch 123, step 58000: Generator loss: 1.8371721441745748, discriminator loss: 0.3426912147402763

Epoch 124, step 58500: Generator loss: 1.7776524059772478, discriminator loss: 0.3560112955570219

Epoch 125, step 59000: Generator loss: 1.7315826637744898, discriminator loss: 0.36849771958589517

Epoch 126, step 59500: Generator loss: 1.7255405280590066, discriminator loss: 0.37416562885045995

Epoch 127, step 60000: Generator loss: 1.7528024375438687, discriminator loss: 0.355661125540733

Epoch 128, step 60500: Generator loss: 1.7974419927597045, discriminator loss: 0.3574418122768404


Epoch 130, step 61000: Generator loss: 1.6283139948844911, discriminator loss: 0.3850496157407761

Epoch 131, step 61500: Generator loss: 1.7279877457618726, discriminator loss: 0.37991200441122075

Epoch 132, step 62000: Generator loss: 1.5958705854415884, discriminator loss: 0.3987177916169162

Epoch 133, step 62500: Generator loss: 1.6742140104770653, discriminator loss: 0.3684275991320606

Epoch 134, step 63000: Generator loss: 1.6141128222942345, discriminator loss: 0.38333785903453804

Epoch 135, step 63500: Generator loss: 1.6564068093299862, discriminator loss: 0.36808489978313447

Epoch 136, step 64000: Generator loss: 1.6705170133113854, discriminator loss: 0.37673580169677723

Epoch 137, step 64500: Generator loss: 1.581768488645553, discriminator loss: 0.3920829309821131

Epoch 138, step 65000: Generator loss: 1.6259170694351204, discriminator loss: 0.3887314875125885

Epoch 139, step 65500: Generator loss: 1.6208510766029376, discriminator loss: 0.39126033759117157

Epoch 140, step 66000: Generator loss: 1.584657865047455, discriminator loss: 0.3988079749345776

Epoch 141, step 66500: Generator loss: 1.4626513309478775, discriminator loss: 0.4404657765626902

Epoch 142, step 67000: Generator loss: 1.4796652994155886, discriminator loss: 0.4258871302008631

Epoch 143, step 67500: Generator loss: 1.426245699644089, discriminator loss: 0.4490515897870066

Epoch 144, step 68000: Generator loss: 1.4654909095764173, discriminator loss: 0.43894404798746084


Epoch 146, step 68500: Generator loss: 1.4531994416713718, discriminator loss: 0.4355195807218549

Epoch 147, step 69000: Generator loss: 1.4415548803806306, discriminator loss: 0.4324091325402259

Epoch 148, step 69500: Generator loss: 1.4597531237602248, discriminator loss: 0.4287781555056571

Epoch 149, step 70000: Generator loss: 1.448197922229766, discriminator loss: 0.43092150253057493

Epoch 150, step 70500: Generator loss: 1.3874552814960486, discriminator loss: 0.4529860781431197

Epoch 151, step 71000: Generator loss: 1.3375571987628954, discriminator loss: 0.46342803007364297

Epoch 152, step 71500: Generator loss: 1.2715160593986519, discriminator loss: 0.4821062604784966

Epoch 153, step 72000: Generator loss: 1.3931213700771348, discriminator loss: 0.4520451138019559

Epoch 154, step 72500: Generator loss: 1.3563301646709451, discriminator loss: 0.4571546506285664

Epoch 155, step 73000: Generator loss: 1.406058851957321, discriminator loss: 0.4381275774240499

Epoch 156, step 73500: Generator loss: 1.4404574294090269, discriminator loss: 0.43083640855550814

Epoch 157, step 74000: Generator loss: 1.3826279881000523, discriminator loss: 0.4408636956810945

Epoch 158, step 74500: Generator loss: 1.381461824893949, discriminator loss: 0.43969752699136755

Epoch 159, step 75000: Generator loss: 1.3752208452224721, discriminator loss: 0.44353364700078995

Epoch 160, step 75500: Generator loss: 1.459187427043915, discriminator loss: 0.41922533226013214


Epoch 162, step 76000: Generator loss: 1.398020961999894, discriminator loss: 0.4378930845856662

Epoch 163, step 76500: Generator loss: 1.349242374897002, discriminator loss: 0.4510687167644505

Epoch 164, step 77000: Generator loss: 1.3467774083614361, discriminator loss: 0.4440856617093086

Epoch 165, step 77500: Generator loss: 1.2900197572708139, discriminator loss: 0.466754495799542

Epoch 166, step 78000: Generator loss: 1.2697726249694845, discriminator loss: 0.4884453754425049

Epoch 167, step 78500: Generator loss: 1.2812501013278965, discriminator loss: 0.4700217588543891

Epoch 168, step 79000: Generator loss: 1.2389389932155594, discriminator loss: 0.50324834638834

Epoch 169, step 79500: Generator loss: 1.2159292020797727, discriminator loss: 0.49224082231521626

Epoch 170, step 80000: Generator loss: 1.2092946174144728, discriminator loss: 0.5077563726902009

Epoch 171, step 80500: Generator loss: 1.2058949604034417, discriminator loss: 0.4964593704938888

Epoch 172, step 81000: Generator loss: 1.244963624238967, discriminator loss: 0.48030589306354565

Epoch 173, step 81500: Generator loss: 1.2427240750789643, discriminator loss: 0.48604615771770465

Epoch 174, step 82000: Generator loss: 1.2228716449737547, discriminator loss: 0.501407857835293

Epoch 175, step 82500: Generator loss: 1.293818171739579, discriminator loss: 0.4723715763092041

Epoch 176, step 83000: Generator loss: 1.2332793104648572, discriminator loss: 0.4825868537425993


Epoch 178, step 83500: Generator loss: 1.1625886285305012, discriminator loss: 0.5163247696757315

Epoch 179, step 84000: Generator loss: 1.2023157367706305, discriminator loss: 0.4875065478682521

Epoch 180, step 84500: Generator loss: 1.2246280994415282, discriminator loss: 0.48611231678724276

Epoch 181, step 85000: Generator loss: 1.216065018653869, discriminator loss: 0.49624158513545974

Epoch 182, step 85500: Generator loss: 1.1837934172153461, discriminator loss: 0.5043423801660538

Epoch 183, step 86000: Generator loss: 1.1905218837261202, discriminator loss: 0.4959087416529657

Epoch 184, step 86500: Generator loss: 1.1908725444078454, discriminator loss: 0.4932625278830527

Epoch 185, step 87000: Generator loss: 1.203426664113998, discriminator loss: 0.5009834297299384

Epoch 186, step 87500: Generator loss: 1.26648530292511, discriminator loss: 0.47784355676174145

Epoch 187, step 88000: Generator loss: 1.14257006919384, discriminator loss: 0.5099722841382033

Epoch 188, step 88500: Generator loss: 1.128186732172966, discriminator loss: 0.5059958605170247

Epoch 189, step 89000: Generator loss: 1.1594977698326114, discriminator loss: 0.4988324818015095

Epoch 190, step 89500: Generator loss: 1.1368910439014444, discriminator loss: 0.5116511986851688

Epoch 191, step 90000: Generator loss: 1.126564669251442, discriminator loss: 0.5098570914268493

Epoch 192, step 90500: Generator loss: 1.2073729650974288, discriminator loss: 0.49162136960029623


Epoch 194, step 91000: Generator loss: 1.1679830052852636, discriminator loss: 0.49945063894987085

Epoch 195, step 91500: Generator loss: 1.1295934059619908, discriminator loss: 0.5254922946691516

Epoch 196, step 92000: Generator loss: 1.136168262362481, discriminator loss: 0.5114749006628992

Epoch 197, step 92500: Generator loss: 1.1145593265295026, discriminator loss: 0.5145469736456869

Epoch 198, step 93000: Generator loss: 1.0910029145479208, discriminator loss: 0.5341875038146977

Epoch 199, step 93500: Generator loss: 1.0984639643430714, discriminator loss: 0.5258551649451262

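For reference, the generator update performed in the loop above can be sketched in isolation. The snippet below is a minimal, self-contained illustration, not the notebook's graded implementation: the tiny `nn.Sequential` networks, their layer sizes, and the hyperparameters are assumptions chosen only so the example runs, while the loss logic (score fakes with the discriminator, compare against labels of all ones) matches the standard GAN generator objective used in this assignment.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in networks with assumed shapes, for illustration only.
z_dim = 16
gen = nn.Sequential(nn.Linear(z_dim, 32), nn.ReLU(), nn.Linear(32, 784), nn.Sigmoid())
disc = nn.Sequential(nn.Linear(784, 32), nn.ReLU(), nn.Linear(32, 1))  # outputs logits

def get_noise(n_samples, z_dim, device='cpu'):
    # Sample noise vectors from a standard normal distribution
    return torch.randn(n_samples, z_dim, device=device)

def get_gen_loss(gen, disc, criterion, num_images, z_dim, device='cpu'):
    # The generator wants the discriminator to classify its fakes as real,
    # so the targets are all ones.
    fake = gen(get_noise(num_images, z_dim, device=device))
    disc_fake_pred = disc(fake)
    return criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))

criterion = nn.BCEWithLogitsLoss()
gen_opt = torch.optim.Adam(gen.parameters(), lr=1e-4)

# One generator update step, mirroring the structure of the training loop.
gen_opt.zero_grad()
gen_loss = get_gen_loss(gen, disc, criterion, num_images=8, z_dim=z_dim)
gen_loss.backward()
gen_opt.step()
print(float(gen_loss))
```

Note that `retain_graph=True` is only needed in the full loop if the same computation graph is reused by a later backward pass; in this isolated one-step sketch a plain `backward()` suffices.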